{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning: algorithms, objectives, and assumptions\n",
    "(c) Deniz Yuret 2019\n",
    "\n",
    "In this notebook we will analyze three classic learning algorithms.\n",
    "* **Perceptron:** ([Rosenblatt, 1957](https://en.wikipedia.org/wiki/Perceptron)) a neuron model trained with a simple algorithm that updates model weights using the input when the prediction is wrong.\n",
    "* **Adaline:** ([Widrow and Hoff, 1960](https://en.wikipedia.org/wiki/ADALINE)) a neuron model trained with a simple algorithm that updates model weights using the error multiplied by the input (aka least mean square (LMS), delta learning rule, or the Widrow-Hoff rule).\n",
    "* **Softmax classification:** ([Cox, 1958](https://en.wikipedia.org/wiki/Multinomial_logistic_regression)) a multiclass generalization of the logistic regression model from statistics (aka multinomial LR, softmax regression, maxent classifier etc.).\n",
    "\n",
    "We look at these learners from different perspectives:\n",
    "* **Algorithm:** First we ask only **how** the learner works, i.e. how it changes after observing each example.\n",
    "* **Objectives:** Next we ask **what** objective guides the algorithm, whether it is optimizing a particular objective function, and whether we can use a generic *optimization algorithm* instead.\n",
    "* **Assumptions:** Finally we ask **why** we think this algorithm makes sense, what prior assumptions it implies, and whether we can use *probabilistic inference* for optimal learning. (TODO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "72"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "using Knet, MLDatasets, Statistics, LinearAlgebra, Random\n",
    "ARRAY = Array{Float32}\n",
    "ENV[\"COLUMNS\"] = 72"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "xtrn, ytrn = MNIST.traindata(eltype(ARRAY))\n",
    "xtst, ytst = MNIST.testdata(eltype(ARRAY))\n",
    "xtrn, xtst = ARRAY(mat(xtrn)), ARRAY(mat(xtst))  # flatten each image into a column\n",
    "function onehot(y)  # one-hot encode the labels; digit 0 becomes class 10\n",
    "    m = zeros(eltype(ARRAY), 10, length(y))\n",
    "    for i in 1:length(y); m[y[i] == 0 ? 10 : y[i], i] = 1; end\n",
    "    ARRAY(m)\n",
    "end\n",
    "ytrn, ytst = onehot(ytrn), onehot(ytst);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "784×60000 Array{Float32,2}\n",
      "10×60000 Array{Float32,2}\n",
      "784×10000 Array{Float32,2}\n",
      "10×10000 Array{Float32,2}\n"
     ]
    }
   ],
   "source": [
    "println.(summary.((xtrn, ytrn, xtst, ytst)));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(60000, 10000, 784, 10)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "NTRN,NTST,XDIM,YDIM = size(xtrn,2), size(xtst,2), size(xtrn,1), size(ytrn,1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10×784 Array{Float32,2}:\n",
       " -0.01215     0.628198  -0.740396  …  -0.630068   0.280267  -0.940353\n",
       " -0.863612   -1.92136    1.39318       0.879314   0.247981   0.348589\n",
       " -0.0308054   0.268047   0.605304      0.610452  -0.276971   0.757551\n",
       " -2.79762    -0.582162   0.268582     -0.607423   1.61061    1.35039\n",
       " -0.153525   -0.196723   0.282439      0.898576  -0.985867   0.146549\n",
       "  0.841244    0.166082   0.484519  …  -0.893788   0.588566  -0.330542\n",
       " -0.0370356  -0.675147  -0.373608      1.81982    0.687714  -0.405147\n",
       " -2.33272    -0.9257    -0.384133      0.338714  -0.491783   0.110337\n",
       " -1.09294    -0.211454  -0.236161     -1.79776    1.93232    0.155101\n",
       "  0.0560821   0.667874   1.21667       0.959533   2.63739   -0.616039"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Model weights\n",
    "w = ARRAY(randn(YDIM,XDIM))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10×60000 Array{Float32,2}:\n",
       " 12.0088     12.6426     6.83006   …  10.5872    16.2395    4.72857\n",
       " -0.977543    8.03122   -0.653524      0.724494  12.2396    1.17001\n",
       " -6.0617     -3.13611    5.28498      -0.447501  -8.94851   5.51493\n",
       " 14.084      12.459      7.1651       15.484     10.3361    6.92621\n",
       "  3.26313    -2.96441  -16.6912        6.37751   -7.10098   2.64587\n",
       "  6.07298   -15.7869    -0.260126  …  -2.72658   -3.23432   0.259716\n",
       "  1.69763     8.43725    6.54435       4.119      6.72252  -3.45105\n",
       "  2.6587     -3.62504    6.70296       0.625211  -1.21695   9.72003\n",
       " -5.79982    -7.12983  -13.4711       -7.95972    6.02882   4.25041\n",
       " -2.86784    -5.86244   -3.21701       0.428129  -1.10273  -0.971755"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Class scores\n",
    "w * xtrn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1×60000 Adjoint{Int64,Array{Int64,1}}:\n",
       " 4  1  4  8  9  8  1  4  9  1  8  …  4  4  1  1  8  3  1  8  4  1  8"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predictions\n",
    "[ argmax(w * xtrn[:,i]) for i in 1:NTRN ]'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1×60000 Adjoint{Int64,Array{Int64,1}}:\n",
       " 5  10  4  1  9  2  1  3  1  4  3  …  8  9  2  9  5  1  8  3  5  6  8"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Correct answers\n",
    "[ argmax(ytrn[:,i]) for i in 1:NTRN ]'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.07835, 0.0769)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accuracy\n",
    "acc(w,x,y) = mean(argmax(w * x, dims=1) .== argmax(y, dims=1))\n",
    "acc(w,xtrn,ytrn), acc(w,xtst,ytst)"
   ]
  },
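  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the accuracy expression above computes, here is a minimal sketch on made-up data. The names `scores`, `labels`, `pred`, `gold`, and `toyacc` and all the values are illustrative assumptions, not part of the MNIST setup. `argmax(..., dims=1)` returns one `CartesianIndex` per column, so the element-wise comparison checks whether the predicted row matches the labeled row for each example.\n",
    "\n",
    "```julia\n",
    "using Statistics\n",
    "\n",
    "# Hypothetical scores for 3 classes (rows) and 2 examples (columns)\n",
    "scores = Float32[1 5; 3 2; 2 4]\n",
    "# One-hot labels: example 1 is class 2, example 2 is class 3\n",
    "labels = Float32[0 0; 1 0; 0 1]\n",
    "\n",
    "pred = argmax(scores, dims=1)   # highest-scoring row in each column\n",
    "gold = argmax(labels, dims=1)   # row holding the 1 in each column\n",
    "toyacc = mean(pred .== gold)    # fraction of columns where they agree\n",
    "```\n",
    "\n",
    "The first example is predicted correctly (class 2) and the second is not (class 1 instead of 3), so `toyacc` is 0.5."
   ]
  },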
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Algorithms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "train (generic function with 2 methods)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Training loop\n",
    "function train(algo,x,y,T=2^20)\n",
    "    w = ARRAY(zeros(size(y,1),size(x,1)))\n",
    "    nexamples = size(x,2)\n",
    "    nextprint = 1\n",
    "    for t = 1:T\n",
    "        i = rand(1:nexamples)\n",
    "        algo(w, x[:,i], y[:,i])  # <== this is where w is updated\n",
    "        if t == nextprint\n",
    "            println((iter=t, accuracy=acc(w,x,y), wnorm=norm(w)))\n",
    "            nextprint = min(2t,T)\n",
    "        end\n",
    "    end\n",
    "    w\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Perceptron"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "perceptron (generic function with 1 method)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
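   ],
   "source": [
    "function perceptron(w,x,y)\n",
    "    guess = argmax(w * x)   # predicted class\n",
    "    class = argmax(y)       # correct class\n",
    "    if guess != class       # update only on mistakes\n",
    "        w[class,:] .+= x    # pull the correct class toward x\n",
    "        w[guess,:] .-= x    # push the wrong guess away from x\n",
    "    end\n",
    "end"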
   ]
  },
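  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small worked example may make the update rule concrete. This sketch assumes the `perceptron` function from the cell above has been defined; the names `wtoy`, `xtoy`, `ytoy` and their values are made up for illustration.\n",
    "\n",
    "```julia\n",
    "wtoy = zeros(Float32, 3, 2)   # 3 classes, 2 input features\n",
    "xtoy = Float32[1, 2]\n",
    "ytoy = Float32[0, 1, 0]       # correct class is 2\n",
    "\n",
    "perceptron(wtoy, xtoy, ytoy)  # all scores tie at 0, argmax picks class 1: a mistake\n",
    "w1 = copy(wtoy)               # row 2 is now xtoy and row 1 is -xtoy\n",
    "perceptron(wtoy, xtoy, ytoy)  # scores are now [-5, 5, 0]: correct, so no update\n",
    "w2 = copy(wtoy)               # identical to w1\n",
    "```\n",
    "\n",
    "After the first call `w1` is `[-1 -2; 1 2; 0 0]`, and the second call leaves the weights unchanged because the guess is already right."
   ]
  },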
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(iter = 1, accuracy = 0.0993, wnorm = 16.183395f0)\n",
      "(iter = 2, accuracy = 0.14571666666666666, wnorm = 18.594425f0)\n",
      "(iter = 4, accuracy = 0.12638333333333332, wnorm = 21.674007f0)\n",
      "(iter = 8, accuracy = 0.24383333333333335, wnorm = 30.619133f0)\n",
      "(iter = 16, accuracy = 0.32233333333333336, wnorm = 38.3449f0)\n",
      "(iter = 32, accuracy = 0.40315, wnorm = 48.85941f0)\n",
      "(iter = 64, accuracy = 0.48985, wnorm = 61.05146f0)\n",
      "(iter = 128, accuracy = 0.4500166666666667, wnorm = 87.74915f0)\n",
      "(iter = 256, accuracy = 0.6148, wnorm = 111.76928f0)\n",
      "(iter = 512, accuracy = 0.69185, wnorm = 140.70638f0)\n",
      "(iter = 1024, accuracy = 0.7066333333333333, wnorm = 184.62056f0)\n",
      "(iter = 2048, accuracy = 0.7945333333333333, wnorm = 221.71835f0)\n",
      "(iter = 4096, accuracy = 0.8195, wnorm = 273.88263f0)\n",
      "(iter = 8192, accuracy = 0.8366166666666667, wnorm = 333.44574f0)\n",
      "(iter = 16384, accuracy = 0.8612833333333333, wnorm = 397.18716f0)\n",
      "(iter = 32768, accuracy = 0.80065, wnorm = 484.5039f0)\n",
      "(iter = 65536, accuracy = 0.8796833333333334, wnorm = 586.9317f0)\n",
      "(iter = 131072, accuracy = 0.8663166666666666, wnorm = 706.42633f0)\n",
      "(iter = 262144, accuracy = 0.8926166666666666, wnorm = 864.61755f0)\n",
      "(iter = 524288, accuracy = 0.8894333333333333, wnorm = 1056.9808f0)\n",
      "(iter = 1048576, accuracy = 0.86, wnorm = 1330.5365f0)\n",
      "  8.709958 seconds (8.80 M allocations: 4.544 GiB, 11.45% gc time)\n"
     ]
    }
   ],
   "source": [
    "# (iter = 1048576, accuracy = 0.8950333333333333, wnorm = 1321.2463f0) in 7 secs\n",
    "@time wperceptron = train(perceptron,xtrn,ytrn);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Adaline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "adaline (generic function with 1 method)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function adaline(w,x,y; lr=0.0001)\n",
    "    error = w * x - y        # raw scores minus one-hot target\n",
    "    w .-= lr * error * x'    # step proportional to error times input\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(iter = 1, accuracy = 0.10441666666666667, wnorm = 0.0010784594f0)\n",
      "(iter = 2, accuracy = 0.10441666666666667, wnorm = 0.0018062098f0)\n",
      "(iter = 4, accuracy = 0.14796666666666666, wnorm = 0.0023069703f0)\n",
      "(iter = 8, accuracy = 0.1129, wnorm = 0.0031464882f0)\n",
      "(iter = 16, accuracy = 0.1797, wnorm = 0.0055435537f0)\n",
      "(iter = 32, accuracy = 0.25838333333333335, wnorm = 0.008615678f0)\n",
      "(iter = 64, accuracy = 0.1327, wnorm = 0.0146190785f0)\n",
      "(iter = 128, accuracy = 0.28781666666666667, wnorm = 0.02509437f0)\n",
      "(iter = 256, accuracy = 0.6364, wnorm = 0.040473588f0)\n",
      "(iter = 512, accuracy = 0.67785, wnorm = 0.065088175f0)\n",
      "(iter = 1024, accuracy = 0.7113166666666667, wnorm = 0.10240222f0)\n",
      "(iter = 2048, accuracy = 0.7436166666666667, wnorm = 0.1651955f0)\n",
      "(iter = 4096, accuracy = 0.7975333333333333, wnorm = 0.25398782f0)\n",
      "(iter = 8192, accuracy = 0.7949, wnorm = 0.35792586f0)\n",
      "(iter = 16384, accuracy = 0.8295, wnorm = 0.46056572f0)\n",
      "(iter = 32768, accuracy = 0.8382666666666667, wnorm = 0.56489986f0)\n",
      "(iter = 65536, accuracy = 0.8461833333333333, wnorm = 0.6659388f0)\n",
      "(iter = 131072, accuracy = 0.8494166666666667, wnorm = 0.7811052f0)\n",
      "(iter = 262144, accuracy = 0.85155, wnorm = 0.92290044f0)\n",
      "(iter = 524288, accuracy = 0.8517833333333333, wnorm = 1.0947992f0)\n",
      "(iter = 1048576, accuracy = 0.8462833333333334, wnorm = 1.2931899f0)\n",
      " 30.894758 seconds (8.44 M allocations: 65.193 GiB, 8.65% gc time)\n"
     ]
    }
   ],
   "source": [
    "# (iter = 1048576, accuracy = 0.8523, wnorm = 1.2907721f0) in 31 secs with lr=0.0001\n",
    "@time wadaline = train(adaline,xtrn,ytrn);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Softmax classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "softmax (generic function with 1 method)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
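   ],
   "source": [
    "function softmax(w,x,y; lr=0.01)\n",
    "    probs = exp.(w * x)      # unnormalized class probabilities\n",
    "    probs ./= sum(probs)     # normalize so they sum to 1\n",
    "    error = probs - y        # predicted probabilities minus one-hot target\n",
    "    w .-= lr * error * x'    # same update form as adaline\n",
    "end"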
   ]
  },
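  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The update above exponentiates raw scores. That is fine for the score magnitudes seen in this notebook, but `exp` overflows in `Float32` once a score exceeds roughly 88. Since subtracting the same constant from every score leaves the normalized probabilities unchanged, a common remedy is to subtract the maximum first. The sketch below illustrates this; `stableprobs` is a hypothetical helper, not something the notebook or Knet defines under that name.\n",
    "\n",
    "```julia\n",
    "# Hypothetical helper: softmax probabilities with the maximum subtracted first\n",
    "function stableprobs(scores)\n",
    "    shifted = scores .- maximum(scores)  # largest entry becomes 0, so exp cannot overflow\n",
    "    p = exp.(shifted)\n",
    "    p ./ sum(p)\n",
    "end\n",
    "\n",
    "big = Float32[1000, 1000]\n",
    "naive = exp.(big) ./ sum(exp.(big))  # exp overflows to Inf, and Inf/Inf gives NaN\n",
    "safe = stableprobs(big)              # both classes get probability 0.5\n",
    "```\n",
    "\n",
    "For scores that do not overflow, both versions give the same probabilities up to rounding."
   ]
  },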
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(iter = 1, accuracy = 0.09871666666666666, wnorm = 0.09837641f0)\n",
      "(iter = 2, accuracy = 0.09871666666666666, wnorm = 0.18767525f0)\n",
      "(iter = 4, accuracy = 0.10088333333333334, wnorm = 0.24415228f0)\n",
      "(iter = 8, accuracy = 0.19958333333333333, wnorm = 0.27368906f0)\n",
      "(iter = 16, accuracy = 0.15525, wnorm = 0.44634444f0)\n",
      "(iter = 32, accuracy = 0.4477833333333333, wnorm = 0.5210072f0)\n",
      "(iter = 64, accuracy = 0.5075, wnorm = 0.81101984f0)\n",
      "(iter = 128, accuracy = 0.5968333333333333, wnorm = 1.2767634f0)\n",
      "(iter = 256, accuracy = 0.7307, wnorm = 1.9046426f0)\n",
      "(iter = 512, accuracy = 0.72075, wnorm = 2.7310653f0)\n",
      "(iter = 1024, accuracy = 0.8428666666666667, wnorm = 3.651138f0)\n",
      "(iter = 2048, accuracy = 0.8595666666666667, wnorm = 4.6969028f0)\n",
      "(iter = 4096, accuracy = 0.86895, wnorm = 5.87642f0)\n",
      "(iter = 8192, accuracy = 0.89265, wnorm = 7.225823f0)\n",
      "(iter = 16384, accuracy = 0.9005666666666666, wnorm = 8.691754f0)\n",
      "(iter = 32768, accuracy = 0.9052166666666667, wnorm = 10.411587f0)\n",
      "(iter = 65536, accuracy = 0.9171833333333334, wnorm = 12.447153f0)\n",
      "(iter = 131072, accuracy = 0.9156, wnorm = 14.749239f0)\n",
      "(iter = 262144, accuracy = 0.9223166666666667, wnorm = 17.719744f0)\n",
      "(iter = 524288, accuracy = 0.9248333333333333, wnorm = 21.654827f0)\n",
      "(iter = 1048576, accuracy = 0.9267166666666666, wnorm = 26.467497f0)\n",
      " 35.911268 seconds (8.83 M allocations: 65.288 GiB, 8.64% gc time)\n"
     ]
    }
   ],
   "source": [
    "# (iter = 1048576, accuracy = 0.9242166666666667, wnorm = 26.523603f0) in 32 secs with lr=0.01\n",
    "@time wsoftmax = train(softmax,xtrn,ytrn);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Objectives"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "optimize (generic function with 1 method)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Training via optimization\n",
    "function optimize(loss,x,y; lr=0.1, iters=2^20)\n",
    "    w = Param(ARRAY(zeros(size(y,1),size(x,1))))\n",
    "    nexamples = size(x,2)\n",
    "    nextprint = 1\n",
    "    for t = 1:iters\n",
    "        i = rand(1:nexamples)\n",
    "        L = @diff loss(w, x[:,i], y[:,i])  # record the loss computation for one example\n",
    "        ∇w = grad(L,w)                      # gradient of that loss with respect to w\n",
    "        w .-= lr * ∇w                       # stochastic gradient descent step\n",
    "        if t == nextprint\n",
    "            println((iter=t, accuracy=acc(w,x,y), wnorm=norm(w)))\n",
    "            nextprint = min(2t,iters)\n",
    "        end\n",
    "    end\n",
    "    w\n",
    "end"
   ]
  },
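  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Param`, `@diff`, and `grad` calls used in `optimize` can be tried in isolation. The following is a minimal sketch on a made-up two-element parameter; the names `v`, `J`, and `g` are illustrative assumptions.\n",
    "\n",
    "```julia\n",
    "using Knet  # re-exports Param, @diff, grad, and value from AutoGrad\n",
    "\n",
    "v = Param([1.0, 2.0])     # mark v as a differentiable parameter\n",
    "J = @diff sum(abs2, v)    # record the computation of v[1]^2 + v[2]^2\n",
    "Jval = value(J)           # the loss itself: 5.0\n",
    "g = grad(J, v)            # its gradient 2v: [2.0, 4.0]\n",
    "```\n",
    "\n",
    "`optimize` does the same thing with `w` in place of `v` and a per-example loss in place of `sum(abs2, v)`, then takes a step against the gradient."
   ]
  },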
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Perceptron minimizes the score difference between the predicted class and the correct class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "perceptronloss (generic function with 1 method)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function perceptronloss(w,x,y)\n",
    "    score = w * x\n",
    "    guess = argmax(score)\n",
    "    class = argmax(y)\n",
    "    score[guess] - score[class]\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "AutoGrad.@zerograd argmax(x;dims=:) # treat argmax as a constant with zero gradient (to be fixed in AutoGrad)"
   ]
  },
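  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A short derivation, in the notation of the code above, of why gradient descent on `perceptronloss` with `lr=1` should reproduce the `perceptron` update. The symbols $s$, $\\hat{c}$, $c$, $J$, and $\\eta$ are introduced here for illustration. Let $s = wx$ be the score vector, $\\hat{c} = \\arg\\max_k s_k$ the guess, $c$ the correct class, and $w_k$ row $k$ of $w$:\n",
    "\n",
    "$$J(w) = s_{\\hat{c}} - s_c = (w_{\\hat{c}} - w_c)\\, x$$\n",
    "\n",
    "Treating $\\hat{c}$ as a constant, which is what the zero-gradient declaration for `argmax` does, the row-wise gradients when $\\hat{c} \\neq c$ are\n",
    "\n",
    "$$\\frac{\\partial J}{\\partial w_{\\hat{c}}} = x^\\top, \\qquad \\frac{\\partial J}{\\partial w_c} = -x^\\top, \\qquad \\frac{\\partial J}{\\partial w_k} = 0 \\ \\text{ for } k \\notin \\{\\hat{c}, c\\}$$\n",
    "\n",
    "When $\\hat{c} = c$ the two terms cancel and the gradient is zero. A step $w \\leftarrow w - \\eta \\nabla_w J$ with $\\eta = 1$ therefore adds $x$ to row $c$ and subtracts it from row $\\hat{c}$, and only on mistakes."
   ]
  },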
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(iter = 1, accuracy = 0.09751666666666667, wnorm = 17.802845f0)\n",
      "(iter = 2, accuracy = 0.11656666666666667, wnorm = 19.03113f0)\n",
      "(iter = 4, accuracy = 0.0977, wnorm = 23.478838f0)\n",
      "(iter = 8, accuracy = 0.1641, wnorm = 25.38451f0)\n",
      "(iter = 16, accuracy = 0.21983333333333333, wnorm = 35.316483f0)\n",
      "(iter = 32, accuracy = 0.2434, wnorm = 46.611935f0)\n",
      "(iter = 64, accuracy = 0.43803333333333333, wnorm = 67.370865f0)\n",
      "(iter = 128, accuracy = 0.57065, wnorm = 87.302666f0)\n",
      "(iter = 256, accuracy = 0.63065, wnorm = 114.32892f0)\n",
      "(iter = 512, accuracy = 0.73915, wnorm = 145.60088f0)\n",
      "(iter = 1024, accuracy = 0.73595, wnorm = 173.54251f0)\n",
      "(iter = 2048, accuracy = 0.7659166666666667, wnorm = 218.51064f0)\n",
      "(iter = 4096, accuracy = 0.80205, wnorm = 279.03473f0)\n",
      "(iter = 8192, accuracy = 0.8284833333333333, wnorm = 334.0729f0)\n",
      "(iter = 16384, accuracy = 0.8654666666666667, wnorm = 413.40176f0)\n",
      "(iter = 32768, accuracy = 0.8414166666666667, wnorm = 495.76038f0)\n",
      "(iter = 65536, accuracy = 0.8770666666666667, wnorm = 592.1692f0)\n",
      "(iter = 131072, accuracy = 0.8697166666666667, wnorm = 708.2958f0)\n",
      "(iter = 262144, accuracy = 0.8744166666666666, wnorm = 855.1693f0)\n",
      "(iter = 524288, accuracy = 0.8869833333333333, wnorm = 1062.1263f0)\n",
      "(iter = 1048576, accuracy = 0.8941833333333333, wnorm = 1326.4813f0)\n",
      " 92.653049 seconds (206.13 M allocations: 77.691 GiB, 10.61% gc time)\n"
     ]
    }
   ],
   "source": [
    "# (iter = 1048576, accuracy = 0.8908833333333334, wnorm = 1322.4888f0) in 62 secs with lr=1\n",
    "@time wperceptron2 = optimize(perceptronloss,xtrn,ytrn,lr=1);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Adaline minimizes the squared error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "quadraticloss (generic function with 1 method)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function quadraticloss(w,x,y)\n",
    "    0.5 * sum(abs2, w * x - y)\n",
    "end"
   ]
  },
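  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a check that this loss matches the `adaline` update, here is the gradient worked out by hand. The symbols $J$ and $\\eta$ (for the loss and the learning rate `lr`) are introduced here for illustration.\n",
    "\n",
    "$$J(w) = \\tfrac{1}{2} \\| wx - y \\|^2$$\n",
    "\n",
    "$$\\nabla_w J = (wx - y)\\, x^\\top$$\n",
    "\n",
    "$$w \\leftarrow w - \\eta\\, (wx - y)\\, x^\\top$$\n",
    "\n",
    "The last line is the `adaline` rule, with `error` playing the role of $wx - y$."
   ]
  },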
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(iter = 1, accuracy = 0.09736666666666667, wnorm = 0.0009124075f0)\n",
      "(iter = 2, accuracy = 0.11793333333333333, wnorm = 0.0011352332f0)\n",
      "(iter = 4, accuracy = 0.14368333333333333, wnorm = 0.0019526349f0)\n",
      "(iter = 8, accuracy = 0.14278333333333335, wnorm = 0.003078862f0)\n",
      "(iter = 16, accuracy = 0.158, wnorm = 0.0047181337f0)\n",
      "(iter = 32, accuracy = 0.12331666666666667, wnorm = 0.008533138f0)\n",
      "(iter = 64, accuracy = 0.3027666666666667, wnorm = 0.015150967f0)\n",
      "(iter = 128, accuracy = 0.22838333333333333, wnorm = 0.026578516f0)\n",
      "(iter = 256, accuracy = 0.48543333333333333, wnorm = 0.042260554f0)\n",
      "(iter = 512, accuracy = 0.6769166666666667, wnorm = 0.06585408f0)\n",
      "(iter = 1024, accuracy = 0.7186, wnorm = 0.10273933f0)\n",
      "(iter = 2048, accuracy = 0.7477166666666667, wnorm = 0.16474141f0)\n",
      "(iter = 4096, accuracy = 0.7714166666666666, wnorm = 0.2536855f0)\n",
      "(iter = 8192, accuracy = 0.8083833333333333, wnorm = 0.3555469f0)\n",
      "(iter = 16384, accuracy = 0.8217333333333333, wnorm = 0.46040946f0)\n",
      "(iter = 32768, accuracy = 0.8362166666666667, wnorm = 0.5608353f0)\n",
      "(iter = 65536, accuracy = 0.8478166666666667, wnorm = 0.6659025f0)\n",
      "(iter = 131072, accuracy = 0.8467, wnorm = 0.77567065f0)\n",
      "(iter = 262144, accuracy = 0.8509833333333333, wnorm = 0.9172911f0)\n",
      "(iter = 524288, accuracy = 0.8456666666666667, wnorm = 1.0924813f0)\n",
      "(iter = 1048576, accuracy = 0.8495, wnorm = 1.2925024f0)\n",
      "105.148895 seconds (194.68 M allocations: 138.529 GiB, 11.46% gc time)\n"
     ]
    }
   ],
   "source": [
    "# (iter = 1048576, accuracy = 0.8498333333333333, wnorm = 1.2882874f0) in 79 secs with lr=0.0001\n",
    "@time wadaline2 = optimize(quadraticloss,xtrn,ytrn,lr=0.0001);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Softmax classifier maximizes the probabilities of correct answers\n",
    "(or minimizes negative log likelihood, aka cross-entropy or softmax loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "negloglik (generic function with 1 method)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function negloglik(w,x,y)\n",
    "    probs = exp.(w * x)\n",
    "    probs = probs / sum(probs)\n",
    "    class = argmax(y)\n",
    "    -log(probs[class])\n",
    "end"
   ]
  },
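  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The gradient of this loss can also be worked out by hand and compared with the `softmax` update. The symbols $s$, $p$, $c$, $J$, and $\\eta$ are introduced here for illustration: $s = wx$ are the scores, $c$ is the correct class, and $y$ is its one-hot encoding.\n",
    "\n",
    "$$p_k = \\frac{e^{s_k}}{\\sum_j e^{s_j}}, \\qquad J(w) = -\\log p_c = -s_c + \\log \\sum_j e^{s_j}$$\n",
    "\n",
    "$$\\frac{\\partial J}{\\partial s_k} = p_k - y_k$$\n",
    "\n",
    "$$\\nabla_w J = (p - y)\\, x^\\top$$\n",
    "\n",
    "A step $w \\leftarrow w - \\eta\\, (p - y)\\, x^\\top$ is the `softmax` rule, with `error` playing the role of $p - y$."
   ]
  },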
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(iter = 1, accuracy = 0.11236666666666667, wnorm = 0.05294757f0)\n",
      "(iter = 2, accuracy = 0.0994, wnorm = 0.11926413f0)\n",
      "(iter = 4, accuracy = 0.18665, wnorm = 0.16653569f0)\n",
      "(iter = 8, accuracy = 0.2708, wnorm = 0.2787828f0)\n",
      "(iter = 16, accuracy = 0.25911666666666666, wnorm = 0.37261048f0)\n",
      "(iter = 32, accuracy = 0.35436666666666666, wnorm = 0.51743394f0)\n",
      "(iter = 64, accuracy = 0.49038333333333334, wnorm = 0.78246355f0)\n",
      "(iter = 128, accuracy = 0.61035, wnorm = 1.2495269f0)\n",
      "(iter = 256, accuracy = 0.7281833333333333, wnorm = 1.8780376f0)\n",
      "(iter = 512, accuracy = 0.7697166666666667, wnorm = 2.7066066f0)\n",
      "(iter = 1024, accuracy = 0.8248, wnorm = 3.670761f0)\n",
      "(iter = 2048, accuracy = 0.8471333333333333, wnorm = 4.6900873f0)\n",
      "(iter = 4096, accuracy = 0.8789166666666667, wnorm = 5.87028f0)\n",
      "(iter = 8192, accuracy = 0.88915, wnorm = 7.179222f0)\n",
      "(iter = 16384, accuracy = 0.8981666666666667, wnorm = 8.627161f0)\n",
      "(iter = 32768, accuracy = 0.9089666666666667, wnorm = 10.263525f0)\n",
      "(iter = 65536, accuracy = 0.911, wnorm = 12.389948f0)\n",
      "(iter = 131072, accuracy = 0.9183333333333333, wnorm = 14.704919f0)\n",
      "(iter = 262144, accuracy = 0.9220166666666667, wnorm = 17.689854f0)\n",
      "(iter = 524288, accuracy = 0.92385, wnorm = 21.412092f0)\n",
      "(iter = 1048576, accuracy = 0.9279833333333334, wnorm = 26.516264f0)\n",
      "133.294565 seconds (353.79 M allocations: 118.099 GiB, 9.62% gc time)\n"
     ]
    }
   ],
   "source": [
    "# (iter = 1048576, accuracy = 0.9283833333333333, wnorm = 26.593485f0) in 120 secs with lr=0.01\n",
    "@time wsoftmax2 = optimize(negloglik,xtrn,ytrn,lr=0.01);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "kernelspec": {
   "display_name": "Julia 1.5.0",
   "language": "julia",
   "name": "julia-1.5"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.5.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
